{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `FoldBatchnormToConv2D`\n",
    "\n",
    "参考：`tvm/tests/python/relax/test_transform_fold_batch_norm_to_conv2d.py`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import tvm\n",
    "import tvm.testing\n",
    "from tvm import relax\n",
    "from tvm.script import relax as R\n",
    "from tvm.script import ir as I\n",
    "from tvm.script.ir_builder import IRBuilder\n",
    "from tvm.ir.module import IRModule\n",
    "from tvm.script.ir_builder import relax as relax_builder\n",
    "from tvm.relax.expr_functor import PyExprVisitor, visitor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_conv2d_batchnorm_sample():\n",
    "    \"\"\"Build a Relax module whose ``main`` runs conv2d followed by batch_norm.\n",
    "\n",
    "    ``main`` declares six float32 parameters in this order: ``data`` of shape\n",
    "    (1, 3, 224, 224), ``weight`` of shape (32, 3, 3, 3), then ``gamma``,\n",
    "    ``beta``, ``mean`` and ``variance``, each of shape (32,). The convolution\n",
    "    uses NCHW/OIHW layouts with stride 1, dilation 1 and padding 1; batch_norm\n",
    "    is applied over axis 1 with epsilon 1e-5.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tvm.IRModule\n",
    "        Module holding the single function ``main``, which returns the first\n",
    "        element of the batch_norm result.\n",
    "    \"\"\"\n",
    "    with IRBuilder() as builder:\n",
    "        with relax_builder.function():\n",
    "            R.func_name(\"main\")\n",
    "\n",
    "            # Declare every parameter before opening the dataflow block so that\n",
    "            # the block itself contains only computation.\n",
    "            data = R.arg(\"data\", R.Tensor((1, 3, 224, 224), \"float32\"))\n",
    "            weight = R.arg(\"weight\", R.Tensor((32, 3, 3, 3), \"float32\"))\n",
    "            gamma = R.arg(\"gamma\", R.Tensor((32,), \"float32\"))\n",
    "            beta = R.arg(\"beta\", R.Tensor((32,), \"float32\"))\n",
    "            mean = R.arg(\"mean\", R.Tensor((32,), \"float32\"))\n",
    "            variance = R.arg(\"variance\", R.Tensor((32,), \"float32\"))\n",
    "\n",
    "            with R.dataflow() as frame:\n",
    "                conv_out = R.emit(\n",
    "                    R.nn.conv2d(\n",
    "                        data,\n",
    "                        weight,\n",
    "                        out_dtype=\"float32\",\n",
    "                        strides=(1, 1),\n",
    "                        dilation=(1, 1),\n",
    "                        padding=(1, 1),\n",
    "                        data_layout=\"NCHW\",\n",
    "                        kernel_layout=\"OIHW\",\n",
    "                        groups=1,\n",
    "                    )\n",
    "                )\n",
    "                # batch_norm produces a tuple; element 0 is the normalized tensor.\n",
    "                bn_out = R.emit(\n",
    "                    R.nn.batch_norm(conv_out, gamma, beta, mean, variance, axis=1, epsilon=1e-5)[0]\n",
    "                )\n",
    "                R.output(bn_out)\n",
    "\n",
    "            R.func_ret_value(frame.output_vars[0])\n",
    "\n",
    "    func = builder.get()\n",
    "\n",
    "    return tvm.IRModule({\"main\": func})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 验证 BatchNorm+Conv2D 融合的正确性"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build target: LLVM for the generated kernels and LLVM again for the host-side code.\n",
    "target = tvm.target.Target(\"llvm\", host=\"llvm\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "创建测试模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "mod = get_conv2d_batchnorm_sample()  # baseline (unoptimized) model\n",
    "mod_fold = get_conv2d_batchnorm_sample()  # separate copy that will be optimized"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "准备测试数据："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeded generator: every Restart & Run All sees the same input and parameters,\n",
    "# so a failing tolerance check can be reproduced.\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "# Random input data (one 3-channel 224x224 image)\n",
    "data_in = tvm.nd.array(rng.random((1, 3, 224, 224)).astype(np.float32))\n",
    "\n",
    "# Random model parameters (Conv2D weight and the BatchNorm parameters)\n",
    "weight_data = tvm.nd.array(rng.random((32, 3, 3, 3)).astype(np.float32))\n",
    "gamma_data = tvm.nd.array(rng.random(32).astype(np.float32))\n",
    "beta_data = tvm.nd.array(rng.random(32).astype(np.float32))\n",
    "mean_data = tvm.nd.array(rng.random(32).astype(np.float32))\n",
    "# Values are drawn from [0, 1), so the variance is never negative.\n",
    "variance_data = tvm.nd.array(rng.random(32).astype(np.float32))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "参数绑定："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bind the same randomly generated parameter values into both modules,\n",
    "# keyed by the parameter names declared on `main`.\n",
    "params_np = dict(\n",
    "    weight=weight_data,\n",
    "    gamma=gamma_data,\n",
    "    beta=beta_data,\n",
    "    mean=mean_data,\n",
    "    variance=variance_data,\n",
    ")\n",
    "\n",
    "# Build the binding pass once and apply it to each module in turn.\n",
    "bind_main_params = tvm.relax.transform.BindParams(\"main\", params_np)\n",
    "mod = bind_main_params(mod)\n",
    "mod_fold = bind_main_params(mod_fold)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "基准模型执行："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Processing pipeline for the original (baseline) model\n",
    "mod = tvm.relax.transform.DecomposeOpsForInference()(mod)  # decompose composite ops\n",
    "ex = tvm.compile(mod, target)  # compile\n",
    "vm = relax.VirtualMachine(ex, tvm.cpu())  # create the virtual machine\n",
    "out = vm[\"main\"](data_in)  # run inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(data: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">224</span>, <span style=\"color: #008000\">224</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">224</span>, <span style=\"color: #008000\">224</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">224</span>, <span style=\"color: #008000\">224</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(data, metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>], strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], dilation<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)\n",
       "            lv_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>expand_dims(metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">1</span>], axis<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">3</span>])\n",
       "            lv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">224</span>, <span style=\"color: #008000\">224</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>subtract(lv, lv_1)\n",
       "            lv2: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>expand_dims(metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">2</span>], axis<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">3</span>])\n",
       "            lv3: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(lv2, R<span style=\"color: #AA22FF; font-weight: bold\">.</span>const(<span style=\"color: #008000\">9.9999997473787516e-06</span>, <span style=\"color: #BA2121\">&quot;float32&quot;</span>))\n",
       "            lv4: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>sqrt(lv3)\n",
       "            lv5: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">224</span>, <span style=\"color: #008000\">224</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>divide(lv1, lv4)\n",
       "            lv6: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>expand_dims(metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">3</span>], axis<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">3</span>])\n",
       "            lv7: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">224</span>, <span style=\"color: #008000\">224</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>multiply(lv5, lv6)\n",
       "            lv8: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>expand_dims(metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">4</span>], axis<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">3</span>])\n",
       "            lv9: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">224</span>, <span style=\"color: #008000\">224</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(lv7, lv8)\n",
       "            lv1_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">224</span>, <span style=\"color: #008000\">224</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">32</span>,), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">32</span>,), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">=</span> lv9, metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">1</span>], metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">2</span>]\n",
       "            lv2_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">224</span>, <span style=\"color: #008000\">224</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> lv1_1[<span style=\"color: #008000\">0</span>]\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(lv2_1)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> lv2_1\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Display the baseline module after decomposition: batch_norm appears as explicit\n",
    "# subtract / divide / multiply / add operations following the conv2d.\n",
    "mod.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "折叠 BN 到 Conv2D："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply the BatchNorm-folding pass, then constant-fold the result\n",
    "mod_fold = relax.transform.FoldBatchnormToConv2D()(mod_fold)\n",
    "mod_fold = relax.transform.FoldConstant()(mod_fold)\n",
    "# Compile and run the folded module on the same input as the baseline\n",
    "ex_fold = tvm.compile(mod_fold, target)\n",
    "vm_fold = relax.VirtualMachine(ex_fold, tvm.cpu())\n",
    "out_fold = vm_fold[\"main\"](data_in)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "结果验证："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check that the results before and after the optimization agree (tolerance 1e-5)\n",
    "tvm.testing.assert_allclose(out.numpy(), out_fold.numpy(), rtol=1e-5, atol=1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(data: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">224</span>, <span style=\"color: #008000\">224</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">224</span>, <span style=\"color: #008000\">224</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            lv5: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">224</span>, <span style=\"color: #008000\">224</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(data, metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>], strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], dilation<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)\n",
       "            lv2: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">224</span>, <span style=\"color: #008000\">224</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(lv5, metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">1</span>])\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(lv2)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> lv2\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Display the folded module: only a conv2d followed by a single add remains.\n",
    "mod_fold.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
